Pytorch的数据读取机制:Dataset类 & DataLoader类以及collate

您所在的位置:网站首页 dataloader worker Pytorch的数据读取机制:Dataset类 & DataLoader类以及collate

Pytorch的数据读取机制:Dataset类 & DataLoader类以及collate

2023-04-03 08:23| 来源: 网络整理| 查看: 265

写在前面

Pytorch读取数据涉及两个类:Dataset类 和 DataLoader类

Dataset类:

接收一个索引,并返回样本需要被继承,并实现 __getitem__ 和 __len__ 方法

DataLoader类:

构建可迭代的数据装载器要给定 dataset 和 batch_size(一)Dataset类

Dataset类是一个抽象类,所有自定义的数据集都需要继承这个类,所有子类都需要重写 __getitem__ 方法(获取每个数据及其对应的label),还可以重写长度类 __len__

Pytorch给出的官方代码如下:

class torch.utils.data.Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): # 子类必须继承 raise NotImplementedError def __len__(self): # 子类必须继承 raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])

Pytorch给出的官方代码限制了标准,要按照它的标准进行数据集的建立:

__getitem__ 就是接收一个索引,获取一个样本对,模型直接通过这一函数获得一对样本对 {x : y}__len__ 是指数据集长度

自己建立dataset的模板可以参考如下:

from torch.utils.data import Dataset class MyDataSet(Dataset): # 创建一个class,继承Dataset类 def __init__(self,data): # 创建初始化类,即根据这个类去创建一个实例时需要运行的函数 self.data = data # self可以把其指定的变量给后面的函数使用,相当于为整个class提供全局变量 def __getitem__(self, index): # index为编号 return self.data[index] def __len__(self): # 数据集的长度 return len(self.data)

一个读取图片数据集的例子:

from torch.utils.data import Dataset from PIL import Image #读取图片 import os #想要获得所有图片的地址,需要导入os(系统库) class MyData(Dataset): def __init__(self,root_dir,label_dir): #通过索引获取图片的地址,需要先创建图片地址的list self.root_dir=root_dir self.label_dir=label_dir self.path=os.path.join(self.root_dir,self.label_dir) self.img_path=os.listdir(self.path) #获得图片下所有的地址 def __getitem__(self, idx): #idx为编号 #获取每一个图片 img_name=self.img_path[idx] #名称 img_item_path=os.path.join(self.root_dir,self.label_dir,img_name) # 每张图片的相对路径 img=Image.open(img_item_path) #读取图片 label=self.label_dir return img,label def __len__(self): return len(self.img_path) #用类创建实例 root_dir="dataset/train" ants_label_dir="ants" bees_label_dir="bees" ants_dataset=MyData(root_dir,ants_label_dir) bees_dataset=MyData(root_dir,bees_label_dir) img, label = ants_dataset[0] img.show() # 可视化第一张图片 #将ants(124张)和bees(121张)两个数据集进行拼接 train_dataset=ants_dataset+bees_dataset

完成 Dataset 的构建后,可以通过 index 提取出一个个sample,以及对 sample 做 transformation 等等

(二)DataLoader类

有了Dataset创建的数据集后,用DataLoader函数就可以加载数据集了。很多情况下,需要进行 mini-batch 的计算,即组装成一个个小的批量,这通过 DataLoader 实现

class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, # 每个epoch是否乱序 sampler=None, batch_sampler=None, num_workers=0, # 是否多进程读取机制,0表示在用主线程计算 collate_fn=None, # 把多个样本组合在一起变成一个mini-batch,不指定该函数的话会调用Pytorch内部默认的函数 pin_memory=False, drop_last=False, # 当样本数不能被batch_size整除时,是否舍弃最后一批数据 )可以通过 debug 了解该过程

一个打包CIFAR10数据集的小例子:

# 用torchvision提供的自定义的数据集 # CIFAR10原本是PIL Image,需要转换成tensor import torchvision.datasets from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter # 准备的测试数据集 test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor()) # 加载测试集 test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False) # batch_size=4,意味着每次从test_data中取4个数据进行打包 writer = SummaryWriter("dataloader") step=0 for data in test_loader: imgs,targets = data #imgs是tensor数据类型 writer.add_images("test_data",imgs,step) step=step+1 writer.close()

用DataLoader构建一个可迭代的数据装载器,传入如何读取数据的机制Dataset,传入batch_size,就可以返回一批批的数据了。DataLoader具体使用是在模型训练的时候,由于它是一个可迭代对象,可以通过以下代码看一下一个批次的数据长啥样:

for data in test_loader: imgs, targets = data print(imgs, targets) break(三)collate_fn参数

DataLoader中的一个参数,实现自定义的batch输出。在不满意默认的 default_collate 的 batch处理结果的情况下,自己写一个collate函数来处理batch数据,以适配自己的模型数据接口

如果不设置collate_fn,我们得到的数据是 list 的形式,且所有的数据格式是 [(data1, label1), (data2, label2), (data3, label3), ......]。但若我们希望的形式是:(data1, data2, data3); (label1, label2, label3)。注意这里是一个示意,就是说我们希望从DataLoader出来的东西,所有的数据部分组织到一起,所有的label又组织到一起,它们的大小是batch_size

def collate_fn(batch): # batch是DataLoader传进来的,相当于是getitem的结果放到一个元组里,这个元组里有batch_size个元素 ([imgs,labels],...) # 自己定义怎么整理数据 real_batch = ... return real_batch # 注意在实例化DataLoader的时候要指定collate_fn参数为自己定义的

例子看这篇就行了:白渠梁:详解torch中的collate_fn参数

参考资料

Pytorch中的dataset类——创建适应任意模型的数据集接口

torch.utils.data - PyTorch 2.0 documentation

系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_pytorch dataloader读取数据_翻滚的小@强的博客-CSDN博客

【我是土堆 - PyTorch教程】学习随手记(已更新 | 已完结 | 10w字超详细版)



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3